Welcome to CodeRATS's Plotly workshop! This consists of two parts:
Plotly makes figures in 4 steps:
There are two versions of Plotly:
We will start by using Plotly Express to make a scatter plot.
# import packages
import pandas as pd
import plotly.express as px
# Step 1: load your data
df = px.data.iris() # reads in data as pandas dataframe (like a table)
display(df)
| | sepal_length | sepal_width | petal_length | petal_width | species | species_id |
|---|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | setosa | 1 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | setosa | 1 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | setosa | 1 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | setosa | 1 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | setosa | 1 |
| ... | ... | ... | ... | ... | ... | ... |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 | virginica | 3 |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 | virginica | 3 |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 | virginica | 3 |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 | virginica | 3 |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 | virginica | 3 |
150 rows × 6 columns
Let's compare petal width (petal_width) to petal length (petal_length) in a scatter plot. Run the following code block, and notice that you made a scatter plot in two lines of code!
Try out different viewing methods:
# 1. Make the figure
fig = px.scatter(df, # df: a pandas dataframe with the data you want to plot. Rows are "samples" and columns are "features" you want to plot, categorize, etc.
x="petal_width", # x: name of column in df you want to use in the x-axis of the scatter plot
y="petal_length") # y: name of column in df you want to use in the y-axis of the scatter plot
# 2. Show the figure in the notebook
fig.show()
Say you want to color these points by species. It is as simple as adding the color= parameter to px.scatter(...). plotly.express will automatically color the points according to the values in the species column of your dataframe and include a legend.
Try:
- In addition to color, use symbol to distinguish the points from different species. Then try symbol_sequence or symbol_map: can you figure out the difference between the two arguments?
- Add hover_data=["sepal_length", "sepal_width"] to px.scatter(...) so that when you hover over a point, you also see its sepal_length and sepal_width values.
- Add marginal_y="violin" and marginal_x="box" to px.scatter(...). More documentation and options are described at https://plotly.com/python/marginal-plots/
# 1. Make the figure
fig = px.scatter(df,
x="petal_width",
y="petal_length",
color='species',
symbol='species',
symbol_map={'setosa':101, 'versicolor':2, 'virginica':203},
hover_data=["sepal_length", "sepal_width"],
marginal_y="violin",
marginal_x="box",
title='Comparing Petals Across Species',
labels={'petal_length': 'Petal Length', 'petal_width': 'Petal Width'}
)
# 2. Show the figure in the notebook
fig.show()
Try some of the other plotly.express plots! View some of the basic options:
Plotly Express is great for quickly making plots to explore your data. Sometimes, though, you need a cleaner figure, or you want to plot something more complicated. Plotly Graph Objects lets you work directly with the components of a Plotly figure, so you can customize every part of it.
Let's make a scatter plot again, but this time with plotly graph objects. We will be loosely following this tutorial: https://towardsdatascience.com/tutorial-on-building-professional-scatter-graphs-in-plotly-python-abe33923f557
# import the plotly graph objects package
import plotly.graph_objects as go
Step 1: Initialize your figure and add data to your plot. To do this, we use add_trace(...). A trace is like a layer of data (or a graph object) to add to the figure. You can call fig.add_trace(...) multiple times to add multiple traces (say a scatter plot overlaying a bar plot). Other helper methods exist such as add_shape(...) and add_hline(...); we will get to those later.
For now, we will only add one trace, which will be a scatter plot. To make this, we will call go.Scatter(...) to make a scatter plot graph object. Then we will add this scatter graph object to our figure my_fig using my_fig.add_trace(...).
# Step 1: make the figure
my_fig = go.Figure()
# Step 2: Add data.
scatter_graph_object = go.Scatter(x = df["petal_width"],
y = df["petal_length"],
# mode is "markers", "lines", or "text", a combination joined by "+" (e.g. "lines+markers"), or "none"
# what do the others look like?
mode = "markers",
)
print(f'scatter_graph_object is a {type(scatter_graph_object)}')
scatter_graph_object is a <class 'plotly.graph_objs._scatter.Scatter'>
my_fig.add_trace(scatter_graph_object)
my_fig.show()
Note that, unlike Plotly Express, a graph object does not take a pandas dataframe; instead, the data is passed directly as an array-like (a list, NumPy array, or pandas Series). This is slightly less convenient but much more flexible. Also note that, unlike Plotly Express, graph objects do not generate a figure title or axis titles. These need to be defined explicitly.
To do this, call update_layout(...). Check out the documentation at https://plotly.com/python/figure-labels/ and https://plotly.com/python/axes, and make the following updates:
All the CSS named colors can be found at https://developer.mozilla.org/en-US/docs/Web/CSS/color_value
my_fig.update_layout(
title={'text':'Iris Petal Sizes',
'font_size':30,
'xanchor':'center',
'x':0.25,
'font_color': 'gray'},
plot_bgcolor = "white"
)
my_fig.update_xaxes(color='gray',
linecolor='gray',
title={'text': 'Petal Width',
'font_size': 16})
my_fig.update_yaxes(color='gray',
linecolor='gray',
title={'text': 'Petal Length',
'font_size': 16})
my_fig.update_traces(marker_color='darkslateblue')
my_fig.show()
Let's look at the underlying data structure of a graph_objects Figure. Try examining my_fig.data and my_fig.layout.
my_fig.data
(Scatter({
'marker': {'color': 'darkslateblue'},
'mode': 'markers',
'x': array([0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.2, 0.1, 0.1,
0.2, 0.4, 0.4, 0.3, 0.3, 0.3, 0.2, 0.4, 0.2, 0.5, 0.2, 0.2, 0.4, 0.2,
0.2, 0.2, 0.2, 0.4, 0.1, 0.2, 0.1, 0.2, 0.2, 0.1, 0.2, 0.2, 0.3, 0.3,
0.2, 0.6, 0.4, 0.3, 0.2, 0.2, 0.2, 0.2, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3,
1.6, 1. , 1.3, 1.4, 1. , 1.5, 1. , 1.4, 1.3, 1.4, 1.5, 1. , 1.5, 1.1,
1.8, 1.3, 1.5, 1.2, 1.3, 1.4, 1.4, 1.7, 1.5, 1. , 1.1, 1. , 1.2, 1.6,
1.5, 1.6, 1.5, 1.3, 1.3, 1.3, 1.2, 1.4, 1.2, 1. , 1.3, 1.2, 1.3, 1.3,
1.1, 1.3, 2.5, 1.9, 2.1, 1.8, 2.2, 2.1, 1.7, 1.8, 1.8, 2.5, 2. , 1.9,
2.1, 2. , 2.4, 2.3, 1.8, 2.2, 2.3, 1.5, 2.3, 2. , 2. , 1.8, 2.1, 1.8,
1.8, 1.8, 2.1, 1.6, 1.9, 2. , 2.2, 1.5, 1.4, 2.3, 2.4, 1.8, 1.8, 2.1,
2.4, 2.3, 1.9, 2.3, 2.5, 2.3, 1.9, 2. , 2.3, 1.8]),
'y': array([1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.6, 1.4, 1.1,
1.2, 1.5, 1.3, 1.4, 1.7, 1.5, 1.7, 1.5, 1. , 1.7, 1.9, 1.6, 1.6, 1.5,
1.4, 1.6, 1.6, 1.5, 1.5, 1.4, 1.5, 1.2, 1.3, 1.5, 1.3, 1.5, 1.3, 1.3,
1.3, 1.6, 1.9, 1.4, 1.6, 1.4, 1.5, 1.4, 4.7, 4.5, 4.9, 4. , 4.6, 4.5,
4.7, 3.3, 4.6, 3.9, 3.5, 4.2, 4. , 4.7, 3.6, 4.4, 4.5, 4.1, 4.5, 3.9,
4.8, 4. , 4.9, 4.7, 4.3, 4.4, 4.8, 5. , 4.5, 3.5, 3.8, 3.7, 3.9, 5.1,
4.5, 4.5, 4.7, 4.4, 4.1, 4. , 4.4, 4.6, 4. , 3.3, 4.2, 4.2, 4.2, 4.3,
3. , 4.1, 6. , 5.1, 5.9, 5.6, 5.8, 6.6, 4.5, 6.3, 5.8, 6.1, 5.1, 5.3,
5.5, 5. , 5.1, 5.3, 5.5, 6.7, 6.9, 5. , 5.7, 4.9, 6.7, 4.9, 5.7, 6. ,
4.8, 4.9, 5.6, 5.8, 6.1, 6.4, 5.6, 5.1, 5.6, 6.1, 5.6, 5.5, 4.8, 5.4,
5.6, 5.1, 5.1, 5.9, 5.7, 5.2, 5. , 5.2, 5.4, 5.1])
}),)
my_fig.layout
Layout({
'plot_bgcolor': 'white',
'template': '...',
'title': {'font': {'color': 'gray', 'size': 30}, 'text': 'Iris Petal Sizes', 'x': 0.25, 'xanchor': 'center'},
'xaxis': {'color': 'gray', 'linecolor': 'gray', 'title': {'font': {'size': 16}, 'text': 'Petal Width'}},
'yaxis': {'color': 'gray', 'linecolor': 'gray', 'title': {'font': {'size': 16}, 'text': 'Petal Length'}}
})
Was this what you expected? Plotly figures are built on nested dictionaries (with the traces held in a tuple, my_fig.data), and they can be inspected much like ordinary dicts and lists. I don't recommend trying to create or modify a Figure from scratch this way; the helper methods exist for a reason. But viewing the underlying data can help you remember an attribute name or understand the current state of the figure.
Next, we want to distinguish the points by category. Plotly Express does this automatically if you pass a column name to the color or symbol parameters. With graph_objects, just as you set marker_color above to a single value, you could instead pass a list giving the desired color of each point (based on its category). Similar techniques work for the marker size or symbol. However, adding a separate trace for each category is usually easier to work with when making other updates later on.
The approach here is to use a for loop. For each unique category, make a new trace (scatter plot graph object) with the corresponding data from your table and add it to your figure.
unique_species = pd.unique(df['species']).tolist() # Get the unique values in the species column of your data
print(unique_species)
['setosa', 'versicolor', 'virginica']
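The order of unique_species decides the order of the traces (and, later, of the subplot columns). pd.unique returns values in order of first appearance without sorting; this sketch on a small made-up Series shows the difference from sorting.

```python
import pandas as pd

# A small made-up column where the species appear out of alphabetical order
species_demo = pd.Series(["versicolor", "setosa", "versicolor", "virginica", "setosa"])

# pd.unique keeps the order in which values first appear
first_seen = pd.unique(species_demo).tolist()
print(first_seen)    # ['versicolor', 'setosa', 'virginica']

# sorted(...) gives alphabetical order instead
alphabetical = sorted(species_demo.unique())
print(alphabetical)  # ['setosa', 'versicolor', 'virginica']
```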
You will need to specify the color used to plot each category. For now, we will make a dictionary with the species as the key and the color (this time as a hex code) as the value. View more colors at https://htmlcolorcodes.com/.
Then, make the figure!
species_colors_dict = {'setosa': "#2471A3",
'versicolor': "#BA4A00",
'virginica': "#884EA0"
}
# 1. initialize figure
fig_species_colored = go.Figure()
# 2. add traces
for species in unique_species:
    # Get only the rows in df where the species value matches the current species
    species_df = df.loc[df['species'] == species]
    # Add a scatter graph object
    fig_species_colored.add_trace(go.Scatter(x = species_df["petal_width"], # Using only the data corresponding to the current species
                                             y = species_df["petal_length"],
                                             mode = "markers",
                                             name = species, # label the points with the species name. Default will be "trace 0", "trace 1", etc.
                                             marker = dict(color = species_colors_dict[species]) # Using the species_colors_dict to get the corresponding color for this species
                                             )
                                  )
#####
# 3. update figure layout
fig_species_colored.update_layout(plot_bgcolor = "white", # background color
font = dict(color = "#909497"),
title = dict(text = "Iris Petal Sizes", font_size=30),
xaxis = dict(title = "Petal Width", linecolor = "#909497"),
yaxis = dict(title = "Petal Length", tickformat = ",", linecolor = "#909497"))
Maybe it would be easier to view the graph if each species had its own plot. However, we still want to be able to compare them. Subplots let us arrange multiple plots in the same figure. Check out this link before continuing: https://plotly.com/python/subplots/#subplots-with-shared-yaxes
from plotly.subplots import make_subplots
#create the blank graph object with make_subplots(...) instead of go.Figure()
fig_subplots = make_subplots(rows = 1,
cols = 3, # provide the dimensions of the subplot
shared_yaxes=True,
subplot_titles= unique_species, # give each subplot a title
horizontal_spacing=0.07)
# like before, iterate through each category (species)
for i, species in enumerate(unique_species): # i = 0, 1, 2
    # Get only the rows in df where the species value matches the current species
    species_df = df.loc[df['species'] == species]
    # Add a scatter graph object
    fig_subplots.add_trace(go.Scatter(x = species_df["petal_width"], # Using only the data corresponding to the current species
                                      y = species_df["petal_length"],
                                      mode = "markers",
                                      name = species, # label the points with the species name. Default will be "trace 0", "trace 1", etc.
                                      marker = dict(color = species_colors_dict[species]) # Using the species_colors_dict to get the corresponding color for this species
                                      ),
                           row = 1,
                           col = i + 1
                           )
#####
# Update layout
fig_subplots.update_layout(plot_bgcolor = "white", # background color
font = dict(color = "#909497"),
title = dict(text = "Iris Petal Sizes", font_size=30),
)
# Update all the subplots' axes at the same time
fig_subplots.update_xaxes(title = "Petal Width", linecolor = "#909497")
fig_subplots.update_yaxes(title = "Petal Length", tickformat = ",", linecolor = "#909497")
The plot still looks a little busy with all the redundant axis labels. The legend is also redundant, since the data is already separated by species. To make it look cleaner, we will:
fig_subplots_clean = go.Figure(fig_subplots)
# Remove legend (set showlegend to False)
fig_subplots_clean.update_layout(showlegend = False)
# fix the x-axis range by hand: shared_xaxes=True only links subplots in the same column, so it would not help in this 1 x 3 grid
fig_subplots_clean.update_xaxes(title_text= '',
showline = False,
range = [0, 3], # [xaxis minimum, xaxis maximum]
tickvals = [0, 1, 2, 3]) # list of values to add tick marks
fig_subplots_clean.update_yaxes(title_text= '',
showline = False)
# use add_annotation() to generate both the x-axis and y-axis titles instead of update_xaxes(title = ...) and update_yaxes(title = ...)
# This allows more precise control of placement
#x axis title
fig_subplots_clean.add_annotation(text = "Petal Width",
xref = "paper",
yref = "paper",
x = 0.5,
y = -0.13,
showarrow = False)
#y axis title
fig_subplots_clean.add_annotation(text = "Petal Length",
xref = "paper",
yref = "paper",
x = -0.08,
y = 0.5,
showarrow = False,
textangle = -90)
fig_subplots_clean.show()
Lastly, let's add a few finishing touches:
import numpy as np
fig_subplots_fancy = go.Figure(fig_subplots_clean)
# iterate through the columns to add all the points in gray
for i in range(len(unique_species)): # i = 0, 1, 2
    species = unique_species[i]
    # Plot all the points in dataframe in gray
    fig_subplots_fancy.add_trace(go.Scatter(x = df["petal_width"], # full df
                                            y = df["petal_length"],
                                            mode = "markers",
                                            name = "all_points",
                                            marker = dict(color = "#909497"), # gray color
                                            opacity = 0.3, # Setting this trace to be more transparent
                                            # we can provide additional data to reference in the hover labels
                                            customdata = np.stack((df['sepal_width'],
                                                                   df['sepal_length'],
                                                                   df['species']),
                                                                  axis=-1),
                                            # this template defines the structure of the hover labels
                                            hovertemplate='Petal Width: %{x:.2f} <br>' +
                                                          'Petal Length: %{y:.2f} <br>' +
                                                          'Sepal Width: %{customdata[0]:.2f} <br>' +
                                                          'Sepal Length: %{customdata[1]:.2f} <br>' +
                                                          '<extra>%{customdata[2]}</extra>',
                                            hoverlabel={'bgcolor': 'white'}
                                            ),
                                 row = 1,
                                 col = i + 1,
                                 )
#sub-title annotation
for i, species in enumerate(unique_species):
    fig_subplots_fancy.add_annotation(text = species,
                                      xref = f'x{i+1}',
                                      yref = "paper",
                                      x = 20,
                                      y = 1.02,
                                      showarrow = False,
                                      xanchor = "left",
                                      font = dict(size = 14, color = "#404647")
                                      )
#create author of the graph
fig_subplots_fancy.add_annotation(text = "Author: Conor Messer", # add your name!
xref = "paper",
yref = "paper",
x = 1.005,
y = -0.145,
showarrow = False,
font = dict(size = 12),
align = "right",
xanchor = "right")
fig_subplots_fancy.show()
Finally, let's save your figure! You can download your plot as a png by clicking the camera icon in the top right of the plot. If you want your figure saved as an svg or pdf (or another image format), use the write_image(...) method. To use this functionality, you may need to install an additional dependency, kaleido (for example, with pip install kaleido). There can be some issues in getting this package to work, however. Let us know if you need help!
fig_subplots_fancy.write_image('final_image.svg', height=500, width=800)
Now it's your turn! Keep playing around with the iris dataset and try other plots (we would especially recommend boxplots https://plotly.com/python/box-plots/ or heatmaps https://plotly.com/python/heatmaps/).
OR
Load in your own data you want to visualize.
# Step 1: Load your data
# Step 2: Format or annotate data
# Step 3: Initialize your figure
# Step 4: Update layout
# Step 5: Annotate